{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": [],
      "name": "inverse_rendering.ipynb"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4WSXmzOVBn-X"
      },
      "source": [
        "##### Copyright 2021 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "gkJ16-cKBuAK"
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hGbNmaA5_si_"
      },
      "source": [
        "# Inverse Rendering\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/inverse_rendering.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VvLMHI88DBVs"
      },
      "source": [
        "This notebook demonstrates an optimization that approximates an image of a 3D shape under unknown camera and lighting using differentiable rendering functions. The variables of optimization include: **camera rotation**, **position**, and **field-of-view**, **lighting direction**, and **background color**. \n",
        "\n",
        "Because the TFG rendering does not include global illumination effects such as shadows, the output rendering will not perfectly match the input shape. To overcome this issue, we use a robust loss based on the [structured similarity metric](https://www.tensorflow.org/api_docs/python/tf/image/ssim).\n",
        "\n",
        "As demonstrated here, accurate derivatives at occlusion boundaries are critical for the optimization to succeed. TensorFlow Graphics implements the **rasterize-then-splat** algorithm [Cole, et al., 2021] to produce derivatives at occlusions. Rasterization with no special treatment of occlusions is provided for comparison; without handling occlusion boundaries, the optimization diverges. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ppRKISWUQIeB"
      },
      "source": [
        "## Setup Notebook"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pZs6dzmQsdY6",
        "cellView": "form"
      },
      "source": [
        "%%capture\n",
        "#@title Install TensorFlow Graphics\n",
        "%pip install tensorflow_graphics"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "__t3mrMftAA2"
      },
      "source": [
        "#@title Fetch the model and target image\n",
        "!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.zip\n",
        "!unzip -o spot.zip\n",
        "!wget -N https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/spot.png"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "3H__1-brS0ms"
      },
      "source": [
        "#@title Import modules\n",
        "import math\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import matplotlib.pyplot as plt\n",
        "from PIL import Image as PilImage\n",
        "import tempfile\n",
        "\n",
        "import tensorflow_graphics.geometry.transformation.quaternion as quat\n",
        "import tensorflow_graphics.geometry.transformation.euler as euler\n",
        "import tensorflow_graphics.geometry.transformation.look_at as look_at\n",
        "import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as rotation_matrix_3d\n",
        "from tensorflow_graphics.rendering.camera import perspective\n",
        "from tensorflow_graphics.rendering import triangle_rasterizer\n",
        "from tensorflow_graphics.rendering import splat\n",
        "\n",
        "from tensorflow_graphics.rendering.texture import texture_map\n",
        "from tensorflow_graphics.geometry.representation.mesh import normals as normals_module"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2TebacgwQKeG"
      },
      "source": [
        "## Load the Spot model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "U98pmFE_OWtn"
      },
      "source": [
        "#@title Load the mesh and texture\n",
        "def load_and_flatten_obj(obj_path):\n",
        "  \"\"\"Loads an .obj and flattens the vertex lists into a single array.\n",
        "\n",
        "  .obj files may contain separate lists of positions, texture coordinates, and\n",
        "  normals. In this case, a triangle vertex will have three values: indices into\n",
        "  each of the position, texture, and normal lists. This function flattens those\n",
        "  lists into a single vertex array by looking for unique combinations of\n",
        "  position, texture, and normal, adding those to list, and then reindexing the\n",
        "  triangles.\n",
        "\n",
        "  This function processes only 'v', 'vt', 'vn', and 'f' .obj lines.\n",
        "\n",
        "  Args:\n",
        "    obj_path: the path to the Wavefront .obj file.\n",
        "\n",
        "  Returns:\n",
        "    a numpy array of vertices and a Mx3 numpy array of triangle indices.\n",
        "\n",
        "    The vertex array will have shape Nx3, Nx5, Nx6, or Nx8, depending on whether\n",
        "    position, position + texture, position + normals, or\n",
        "    position + texture + normals are present.\n",
        "\n",
        "    Unlike .obj, the triangle vertex indices are 0-based.\n",
        "  \"\"\"\n",
        "  VERTEX_TYPES = ['v', 'vt', 'vn']\n",
        "\n",
        "  vertex_lists = {n: [] for n in VERTEX_TYPES}\n",
        "  flat_vertices_list = []\n",
        "  flat_vertices_indices = {}\n",
        "  flat_triangles = []\n",
        "  # Keep track of encountered vertex types.\n",
        "  has_type = {t: False for t in VERTEX_TYPES}\n",
        "\n",
        "  with open(obj_path) as obj_file:\n",
        "    for line in iter(obj_file):\n",
        "      tokens = line.split()\n",
        "      if not tokens:\n",
        "        continue\n",
        "      line_type = tokens[0]\n",
        "      # We skip lines not starting with v, vt, vn, or f.\n",
        "      if line_type in VERTEX_TYPES:\n",
        "        vertex_lists[line_type].append([float(x) for x in tokens[1:]])\n",
        "      elif line_type == 'f':\n",
        "        triangle = []\n",
        "        for i in range(3):\n",
        "          # The vertex name is one of the form: 'v', 'v/vt', 'v//vn', or\n",
        "          # 'v/vt/vn'.\n",
        "          vertex_name = tokens[i + 1]\n",
        "          if vertex_name in flat_vertices_indices:\n",
        "            triangle.append(flat_vertices_indices[vertex_name])\n",
        "            continue\n",
        "          # Extract all vertex type indices ('' for unspecified).\n",
        "          vertex_indices = vertex_name.split('/')\n",
        "          while len(vertex_indices) < 3:\n",
        "            vertex_indices.append('')\n",
        "          flat_vertex = []\n",
        "          for vertex_type, index in zip(VERTEX_TYPES, vertex_indices):\n",
        "            if index:\n",
        "              # obj triangle indices are 1 indexed, so subtract 1 here.\n",
        "              flat_vertex += vertex_lists[vertex_type][int(index) - 1]\n",
        "              has_type[vertex_type] = True\n",
        "            else:\n",
        "              # Append zeros for missing attributes.\n",
        "              flat_vertex += [0, 0] if vertex_type == 'vt' else [0, 0, 0]\n",
        "          flat_vertex_index = len(flat_vertices_list)\n",
        "\n",
        "          flat_vertices_list.append(flat_vertex)\n",
        "          flat_vertices_indices[vertex_name] = flat_vertex_index\n",
        "          triangle.append(flat_vertex_index)\n",
        "        flat_triangles.append(triangle)\n",
        "\n",
        "  # Keep only vertex types that are used in at least one vertex.\n",
        "  flat_vertices_array = np.float32(flat_vertices_list)\n",
        "  flat_vertices = flat_vertices_array[:, :3]\n",
        "  if has_type['vt']:\n",
        "    flat_vertices = np.concatenate((flat_vertices, flat_vertices_array[:, 3:5]),\n",
        "                                   axis=-1)\n",
        "  if has_type['vn']:\n",
        "    flat_vertices = np.concatenate((flat_vertices, flat_vertices_array[:, -3:]),\n",
        "                                   axis=-1)\n",
        "\n",
        "  return flat_vertices, np.int32(flat_triangles)\n",
        "\n",
        "def load_texture(texture_filename):\n",
        "  \"\"\"Returns a texture image loaded from a file (float32 in [0,1] range).\"\"\"\n",
        "  with open(texture_filename, 'rb') as f:\n",
        "    return np.asarray(PilImage.open(f)).astype(np.float32) / 255.0\n",
        "\n",
        "spot_texture_map = load_texture('spot/spot_texture.png')\n",
        "\n",
        "vertices, triangles = load_and_flatten_obj('spot/spot_triangulated.obj')\n",
        "vertices, uv_coords = tf.split(vertices, (3,2), axis=-1)\n",
        "normals = normals_module.vertex_normals(vertices, triangles)\n",
        "print(vertices.shape, uv_coords.shape, normals.shape, triangles.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "WHDaIoh7RPSA"
      },
      "source": [
        "#@title Load and display target image\n",
        "from PIL import Image as PilImage\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def show_image(image, show=True):\n",
        "  plt.imshow(image, origin='lower')\n",
        "  plt.axis('off')\n",
        "  if show:\n",
        "    plt.show()\n",
        "\n",
        "with open('spot.png', 'rb') as target_file:\n",
        "  target_image = PilImage.open(target_file)\n",
        "  target_image.thumbnail([200,200])\n",
        "  target_image = np.array(target_image).astype(np.float32) / 255.0\n",
        "  target_image = np.flipud(target_image)\n",
        "\n",
        "image_width = target_image.shape[1]\n",
        "image_height = target_image.shape[0]\n",
        "\n",
        "show_image(target_image)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FrVYSAlpQoM1"
      },
      "source": [
        "## Set up rendering functions and variables"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TDDSf5YQSXVV",
        "cellView": "form"
      },
      "source": [
        "#@title Initial variables\n",
        "import math\n",
        "\n",
        "def make_initial_variables():\n",
        "  camera_translation = tf.Variable([[0.0, 0.0, -4]])\n",
        "  fov = tf.Variable([40.0 * math.pi / 180.0])\n",
        "  quaternion = tf.Variable(tf.expand_dims(\n",
        "      quat.from_euler((0.0, 0.0, 0.0)), axis=0))\n",
        "  background_color = tf.Variable([1.0, 1.0, 1.0, 1.0])\n",
        "  light_direction = tf.Variable([0.5, 0.5, 1.0])\n",
        "  return {\n",
        "      'quaternion': quaternion,\n",
        "      'translation': camera_translation,\n",
        "      'fov': fov,\n",
        "      'background_color': background_color,\n",
        "      'light_direction': light_direction\n",
        "  }\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JEAeBBhvTAez",
        "cellView": "form"
      },
      "source": [
        "#@title Rendering functions\n",
        "\n",
        "\n",
        "def shade(rasterized, light_direction, ka=0.5, kd=0.5):\n",
        "  \"\"\"Shades the input rasterized buffer using a basic illumination model.\n",
        "\n",
        "  Args:\n",
        "    rasterized: a dictionary of interpolated attribute buffers.\n",
        "    light_direction: a vector defining the direction of a single light.\n",
        "    ka: ambient lighting coefficient.\n",
        "    kd: diffuse lighting coefficient.\n",
        "\n",
        "  Returns:\n",
        "    an RGBA buffer of shaded pixels.\n",
        "  \"\"\"\n",
        "  textured = texture_map.map_texture(rasterized['uv'][tf.newaxis, ...],\n",
        "                                     spot_texture_map)[0, ...]\n",
        "\n",
        "  light_direction = tf.reshape(light_direction, [1, 1, 3])\n",
        "  light_direction = tf.math.l2_normalize(light_direction, axis=-1)\n",
        "  n_dot_l = tf.clip_by_value(\n",
        "      tf.reduce_sum(\n",
        "          rasterized['normals'] * light_direction, axis=2, keepdims=True), 0.0,\n",
        "      1.0)\n",
        "  ambient = textured * ka\n",
        "  diffuse = textured * kd * n_dot_l\n",
        "  lit = ambient + diffuse\n",
        "\n",
        "  lit_rgba = tf.concat((lit, rasterized['mask']), -1)\n",
        "  return lit_rgba\n",
        "\n",
        "\n",
        "def rasterize_without_splatting(projection, image_width, image_height,\n",
        "                                light_direction):\n",
        "  rasterized = triangle_rasterizer.rasterize(vertices, triangles, {\n",
        "      'uv': uv_coords,\n",
        "      'normals': normals\n",
        "  }, projection, (image_height, image_width))\n",
        "\n",
        "  lit = shade(rasterized, light_direction)\n",
        "  return lit\n",
        "\n",
        "\n",
        "def rasterize_then_splat(projection, image_width, image_height,\n",
        "                         light_direction):\n",
        "  return splat.rasterize_then_splat(vertices, triangles, {\n",
        "      'uv': uv_coords,\n",
        "      'normals': normals\n",
        "  }, projection, (image_height, image_width),\n",
        "                                    lambda d: shade(d, light_direction))\n",
        "\n",
        "\n",
        "def render_forward(variables, rasterization_func):\n",
        "  camera_translation = variables['translation']\n",
        "  eye = camera_translation\n",
        "  # Place the \"center\" of the scene along the Z axis from the camera.\n",
        "  center = tf.constant([[0.0, 0.0, 1.0]]) + camera_translation\n",
        "  world_up = tf.constant([[0.0, 1.0, 0.0]])\n",
        "\n",
        "  normalized_quaternion = variables['quaternion'] / tf.norm(\n",
        "      variables['quaternion'], axis=1, keepdims=True)\n",
        "  model_rotation_3x3 = rotation_matrix_3d.from_quaternion(normalized_quaternion)\n",
        "  model_rotation_4x4 = tf.pad(model_rotation_3x3 - tf.eye(3),\n",
        "                              ((0, 0), (0, 1), (0, 1))) + tf.eye(4)\n",
        "\n",
        "  look_at_4x4 = look_at.right_handed(eye, center, world_up)\n",
        "  perspective_4x4 = perspective.right_handed(variables['fov'],\n",
        "                                             (image_width / image_height,),\n",
        "                                             (0.01,), (10.0,))\n",
        "\n",
        "  projection = tf.matmul(perspective_4x4,\n",
        "                         tf.matmul(look_at_4x4, model_rotation_4x4))\n",
        "\n",
        "  rendered = rasterization_func(projection, image_width, image_height,\n",
        "                                variables['light_direction'])\n",
        "\n",
        "  background_rgba = variables['background_color']\n",
        "  background_rgba = tf.tile(\n",
        "      tf.reshape(background_rgba, [1, 1, 4]), [image_height, image_width, 1])\n",
        "  composited = rendered + background_rgba * (1.0 - rendered[..., 3:4])\n",
        "  return composited"
      ],
      "execution_count": null,
      "outputs": []
    },
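    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading aid, the rendering functions above can be summarized by the equations below. This is a sketch derived from the code in this notebook; the symbols ($T$, $\\mathbf{n}$, $\\boldsymbol{\\ell}$, $\\mathbf{q}$, $\\mathbf{t}$, $M$, $P$, $L$, $R$, $C$, $B$, $\\alpha$) are notation assumed here for illustration and are not names from TensorFlow Graphics.\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "\\mathbf{c}_{\\text{lit}} &= T(u, v)\\,\\bigl(k_a + k_d \\operatorname{clip}(\\mathbf{n} \\cdot \\hat{\\boldsymbol{\\ell}},\\, 0,\\, 1)\\bigr), \\qquad \\hat{\\boldsymbol{\\ell}} = \\boldsymbol{\\ell} / \\lVert \\boldsymbol{\\ell} \\rVert \\\\\n",
        "M &= P(\\text{fov},\\, w/h,\\, 0.01,\\, 10)\\; L\\bigl(\\mathbf{t},\\, \\mathbf{t} + (0, 0, 1)^\\top,\\, (0, 1, 0)^\\top\\bigr)\\; R(\\mathbf{q} / \\lVert \\mathbf{q} \\rVert) \\\\\n",
        "C_{\\text{out}} &= C + B\\,(1 - \\alpha)\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "Here $T(u, v)$ is the texture color at the interpolated texture coordinates, $\\mathbf{n}$ is the interpolated normal, $\\boldsymbol{\\ell}$ is `light_direction`, and $k_a = k_d = 0.5$ by default. $P$, $L$, and $R$ stand for the perspective, look-at, and rotation matrices built in `render_forward` from `fov`, `translation` ($\\mathbf{t}$), and `quaternion` ($\\mathbf{q}$), with $w/h$ the image aspect ratio. In the last line, $C$ is the RGBA output of the rasterization function, $\\alpha$ is its fourth channel, and $B$ is `background_color`."
      ]
    },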
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "9UU1JD7HPKLv"
      },
      "source": [
        "#@title Loss function\n",
        "def ssim_loss(target, rendered):\n",
        "  target_yuv = tf.compat.v2.image.rgb_to_yuv(target[..., :3])\n",
        "  rendered_yuv = tf.compat.v2.image.rgb_to_yuv(rendered[..., :3])\n",
        "  ssim = tf.compat.v2.image.ssim(target_yuv, rendered_yuv, max_val=1.0)\n",
        "  return 1.0 - ssim"
      ],
      "execution_count": null,
      "outputs": []
    },
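    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Written as an equation, the loss above is roughly the following. This is a sketch of what `ssim_loss` computes; $I_{\\text{target}}$ and $I_{\\text{rendered}}$ are notation assumed here for the RGB channels of the two images, with the alpha channel dropped.\n",
        "\n",
        "$$\n",
        "\\mathcal{L}(I_{\\text{target}}, I_{\\text{rendered}}) = 1 - \\operatorname{SSIM}\\bigl(\\operatorname{YUV}(I_{\\text{target}}),\\, \\operatorname{YUV}(I_{\\text{rendered}})\\bigr)\n",
        "$$\n",
        "\n",
        "SSIM is at most 1 and equals 1 for identical images, so the loss is non-negative and reaches zero only when the rendering matches the target."
      ]
    },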
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "0yYdpz-hPOLL"
      },
      "source": [
        "#@title Backwards pass\n",
        "@tf.function\n",
        "def render_grad(target, variables, rasterization_func):\n",
        "  with tf.GradientTape() as g:\n",
        "    rendered = render_forward(variables, rasterization_func)\n",
        "    loss_value = ssim_loss(target, rendered)\n",
        "  grads = g.gradient(loss_value, variables)\n",
        "  return rendered, grads, loss_value"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lB9Xq96vO3us"
      },
      "source": [
        "## Run optimization"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "vdcE4k8ZhaOT"
      },
      "source": [
        "#@title Run gradient descent\n",
        "variables = make_initial_variables()\n",
        "\n",
        "# Change this to rasterize to test without RtS\n",
        "rasterization_mode = 'rasterize then splat'  #@param [ \"rasterize then splat\", \"rasterize without splatting\"]\n",
        "rasterization_func = (\n",
        "    rasterize_then_splat\n",
        "    if rasterization_mode == 'rasterize then splat' else rasterize_without_splatting)\n",
        "\n",
        "learning_rate = 0.02 #@param {type: \"slider\", min: 0.002, max: 0.05, step: 0.002}\n",
        "start = render_forward(variables, rasterization_func)\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate)\n",
        "animation_images = [start.numpy()]\n",
        "num_steps = 300 #@param { type: \"slider\", min: 100, max: 2000, step: 100}\n",
        "\n",
        "for i in range(num_steps):\n",
        "  current, grads, loss = render_grad(target_image, variables, rasterization_func)\n",
        "  to_apply = [(grads[k], variables[k]) for k in variables.keys()]\n",
        "  optimizer.apply_gradients(to_apply)\n",
        "  if i > 0 and i % 10 == 0:\n",
        "    animation_images.append(current.numpy())\n",
        "  if i % 100 == 0:\n",
        "    print('Loss at step {:03d}: {:.3f}'.format(i, loss.numpy()))\n",
        "    pass\n",
        "print('Final loss {:03d}: {:.3f}'.format(i, loss.numpy()))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "dkThT6Oshf-9"
      },
      "source": [
        "#@title Display results\n",
        "plt.figure(figsize=[18,6])\n",
        "plt.subplot(1,4,1)\n",
        "plt.title('Initialization')\n",
        "show_image(np.clip(start, 0.0, 1.0), show=False)\n",
        "plt.subplot(1,4,2)\n",
        "plt.title('After Optimization')\n",
        "show_image(np.clip(current, 0.0, 1.0), show=False)\n",
        "plt.subplot(1,4,3)\n",
        "plt.title('Target')\n",
        "show_image(target_image, show=False)\n",
        "plt.subplot(1,4,4)\n",
        "plt.title('Difference')\n",
        "show_image(current[...,0] - target_image[...,0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "aQwoCByjhh44"
      },
      "source": [
        "%%capture\n",
        "#@title Display animation\n",
        "import matplotlib.animation as animation\n",
        "\n",
        "def save_animation(images):\n",
        "  fig = plt.figure(figsize=(8, 8))\n",
        "  plt.axis('off')\n",
        "  ims = [[plt.imshow(np.flipud(np.clip(i, 0.0, 1.0)))] for i in images]\n",
        "  return animation.ArtistAnimation(fig, ims, interval=50, blit=True)\n",
        "\n",
        "anim = save_animation(animation_images)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LTG19FACaTTK"
      },
      "source": [
        "from IPython.display import HTML\n",
        "HTML(anim.to_jshtml())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "r2ibcdm1hkiC"
      },
      "source": [
        "#@title Display initial and optimized camera parameters\n",
        "def print_camera_params(v):\n",
        "  print(f\"FoV (degrees): {v['fov'].numpy() * 180.0 / math.pi}\")\n",
        "  print(f\"Position: {v['translation'].numpy()}\")\n",
        "  print(f\"Orientation (xyz angles): {euler.from_quaternion(v['quaternion']).numpy()}\")\n",
        "\n",
        "print(\"INITIAL CAMERA:\")\n",
        "print_camera_params(make_initial_variables())\n",
        "print(\"\\nOPTIMIZED CAMERA:\")\n",
        "print_camera_params(variables)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}